package api

import (
	"net/http"
	"strconv"
	"vulnmain/services"

	"github.com/gin-gonic/gin"
)

// vulnService is the package-level services.VulnService instance that the
// vulnerability handlers in this file delegate to.
var vulnService = &services.VulnService{}

// CreateVuln creates a vulnerability.
//
// The JSON request body is bound into a services.VulnCreateRequest and passed,
// together with the caller's "user_id" taken from the gin context, to
// vulnService.CreateVuln.
//
// Responses:
//   - 400 when the body cannot be bound or the service returns an error
//   - 401 when "user_id" is absent from the context or is not a uint
//   - 200 with the created vulnerability in "data" on success
func CreateVuln(c *gin.Context) {
	var req services.VulnCreateRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{
			"code": 400,
			"msg":  "参数错误: " + err.Error(),
		})
		return
	}

	// Use a comma-ok assertion so that a context value of an unexpected type
	// is rejected as unauthenticated instead of panicking the handler.
	rawUserID, exists := c.Get("user_id")
	userID, ok := rawUserID.(uint)
	if !exists || !ok {
		c.JSON(http.StatusUnauthorized, gin.H{
			"code": 401,
			"msg":  "用户未认证",
		})
		return
	}

	vuln, err := vulnService.CreateVuln(&req, userID)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{
			"code": 400,
			"msg":  err.Error(),
		})
		return
	}

	c.JSON(http.StatusOK, gin.H{
		"code": 200,
		"msg":  "创建成功",
		"data": vuln,
	})
}

// GetVuln returns the details of a single vulnerability.
//
// The vulnerability ID comes from the ":id" path parameter and must parse as
// an unsigned 32-bit decimal. The caller's "user_id" and "role_code" are read
// from the gin context and forwarded to vulnService.GetVulnByID.
//
// Responses:
//   - 400 when the ID cannot be parsed
//   - 401 when "user_id" is absent or not a uint, or "role_code" is absent or
//     not a string
//   - 404 when the service returns an error
//   - 200 with the vulnerability in "data" on success
func GetVuln(c *gin.Context) {
	vulnID, err := strconv.ParseUint(c.Param("id"), 10, 32)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{
			"code": 400,
			"msg":  "漏洞ID格式错误",
		})
		return
	}

	// Comma-ok assertions: a context value of an unexpected type yields a 401
	// response instead of a panic.
	rawUserID, exists := c.Get("user_id")
	userID, ok := rawUserID.(uint)
	if !exists || !ok {
		c.JSON(http.StatusUnauthorized, gin.H{
			"code": 401,
			"msg":  "用户未认证",
		})
		return
	}

	rawRole, exists := c.Get("role_code")
	roleCode, ok := rawRole.(string)
	if !exists || !ok {
		c.JSON(http.StatusUnauthorized, gin.H{
			"code": 401,
			"msg":  "用户角色信息缺失",
		})
		return
	}

	vuln, err := vulnService.GetVulnByID(uint(vulnID), userID, roleCode)
	if err != nil {
		c.JSON(http.StatusNotFound, gin.H{
			"code": 404,
			"msg":  err.Error(),
		})
		return
	}

	c.JSON(http.StatusOK, gin.H{
		"code": 200,
		"msg":  "获取成功",
		"data": vuln,
	})
}

// UpdateVuln updates an existing vulnerability.
//
// The vulnerability ID comes from the ":id" path parameter and must parse as
// an unsigned 32-bit decimal. The JSON body is bound into a
// services.VulnUpdateRequest. The caller's "user_id" and "role_code" are read
// from the gin context and forwarded to vulnService.UpdateVuln.
//
// Responses:
//   - 400 when the ID or body is invalid, or the service returns an error
//   - 401 when "user_id" is absent or not a uint, or "role_code" is absent or
//     not a string
//   - 200 with the updated vulnerability in "data" on success
func UpdateVuln(c *gin.Context) {
	vulnID, err := strconv.ParseUint(c.Param("id"), 10, 32)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{
			"code": 400,
			"msg":  "漏洞ID格式错误",
		})
		return
	}

	var req services.VulnUpdateRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{
			"code": 400,
			"msg":  "参数错误: " + err.Error(),
		})
		return
	}

	// Comma-ok assertions: a context value of an unexpected type yields a 401
	// response instead of a panic.
	rawUserID, exists := c.Get("user_id")
	userID, ok := rawUserID.(uint)
	if !exists || !ok {
		c.JSON(http.StatusUnauthorized, gin.H{
			"code": 401,
			"msg":  "用户未认证",
		})
		return
	}

	rawRole, exists := c.Get("role_code")
	roleCode, ok := rawRole.(string)
	if !exists || !ok {
		c.JSON(http.StatusUnauthorized, gin.H{
			"code": 401,
			"msg":  "用户角色信息缺失",
		})
		return
	}

	vuln, err := vulnService.UpdateVuln(uint(vulnID), &req, userID, roleCode)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{
			"code": 400,
			"msg":  err.Error(),
		})
		return
	}

	c.JSON(http.StatusOK, gin.H{
		"code": 200,
		"msg":  "更新成功",
		"data": vuln,
	})
}

// DeleteVuln deletes a vulnerability.
//
// The vulnerability ID comes from the ":id" path parameter and must parse as
// an unsigned 32-bit decimal. The caller's "user_id" is read from the gin
// context and forwarded to vulnService.DeleteVuln.
//
// Responses:
//   - 400 when the ID cannot be parsed or the service returns an error
//   - 401 when "user_id" is absent from the context or is not a uint
//   - 200 on success (no "data" field)
func DeleteVuln(c *gin.Context) {
	vulnID, err := strconv.ParseUint(c.Param("id"), 10, 32)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{
			"code": 400,
			"msg":  "漏洞ID格式错误",
		})
		return
	}

	// Use a comma-ok assertion so that a context value of an unexpected type
	// is rejected as unauthenticated instead of panicking the handler.
	rawUserID, exists := c.Get("user_id")
	userID, ok := rawUserID.(uint)
	if !exists || !ok {
		c.JSON(http.StatusUnauthorized, gin.H{
			"code": 401,
			"msg":  "用户未认证",
		})
		return
	}

	if err := vulnService.DeleteVuln(uint(vulnID), userID); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{
			"code": 400,
			"msg":  err.Error(),
		})
		return
	}

	c.JSON(http.StatusOK, gin.H{
		"code": 200,
		"msg":  "删除成功",
	})
}

// GetVulnList returns a list of vulnerabilities.
//
// Query-string parameters are bound into a services.VulnListRequest. The
// caller's "user_id" and "role_code" are read from the gin context and copied
// into the request's CurrentUserID and CurrentUserRole fields before it is
// passed to vulnService.GetVulnList.
//
// Responses:
//   - 400 when the query parameters cannot be bound
//   - 401 when "user_id" is absent or not a uint, or "role_code" is absent or
//     not a string
//   - 500 when the service returns an error
//   - 200 with the service response in "data" on success
func GetVulnList(c *gin.Context) {
	var req services.VulnListRequest
	if err := c.ShouldBindQuery(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{
			"code": 400,
			"msg":  "参数错误: " + err.Error(),
		})
		return
	}

	// Fetch the current user's identity for permission control. Comma-ok
	// assertions turn an unexpected value type into a 401 instead of a panic.
	rawUserID, exists := c.Get("user_id")
	userID, ok := rawUserID.(uint)
	if !exists || !ok {
		c.JSON(http.StatusUnauthorized, gin.H{
			"code": 401,
			"msg":  "用户未认证",
		})
		return
	}

	rawRole, exists := c.Get("role_code")
	roleCode, ok := rawRole.(string)
	if !exists || !ok {
		c.JSON(http.StatusUnauthorized, gin.H{
			"code": 401,
			"msg":  "用户角色信息缺失",
		})
		return
	}

	// Populate the permission-control fields of the request.
	req.CurrentUserID = userID
	req.CurrentUserRole = roleCode

	response, err := vulnService.GetVulnList(&req)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"code": 500,
			"msg":  err.Error(),
		})
		return
	}

	c.JSON(http.StatusOK, gin.H{
		"code": 200,
		"msg":  "获取成功",
		"data": response,
	})
}

// AuditVuln audits (reviews) a vulnerability.
//
// The vulnerability ID comes from the ":id" path parameter and must parse as
// an unsigned 32-bit decimal. The JSON body is bound into a
// services.AuditRequest. The caller's "user_id" is read from the gin context
// and forwarded to vulnService.AuditVuln.
//
// Responses:
//   - 400 when the ID or body is invalid, or the service returns an error
//   - 401 when "user_id" is absent from the context or is not a uint
//   - 200 on success (no "data" field)
func AuditVuln(c *gin.Context) {
	vulnID, err := strconv.ParseUint(c.Param("id"), 10, 32)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{
			"code": 400,
			"msg":  "漏洞ID格式错误",
		})
		return
	}

	var req services.AuditRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{
			"code": 400,
			"msg":  "参数错误: " + err.Error(),
		})
		return
	}

	// Use a comma-ok assertion so that a context value of an unexpected type
	// is rejected as unauthenticated instead of panicking the handler.
	rawUserID, exists := c.Get("user_id")
	userID, ok := rawUserID.(uint)
	if !exists || !ok {
		c.JSON(http.StatusUnauthorized, gin.H{
			"code": 401,
			"msg":  "用户未认证",
		})
		return
	}

	if err := vulnService.AuditVuln(uint(vulnID), &req, userID); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{
			"code": 400,
			"msg":  err.Error(),
		})
		return
	}

	c.JSON(http.StatusOK, gin.H{
		"code": 200,
		"msg":  "审核成功",
	})
}

// FixVuln marks a vulnerability as fixed.
//
// The vulnerability ID comes from the ":id" path parameter and must parse as
// an unsigned 32-bit decimal. The caller's "user_id" is read from the gin
// context and forwarded to vulnService.FixVuln.
//
// Responses:
//   - 400 when the ID cannot be parsed or the service returns an error
//   - 401 when "user_id" is absent from the context or is not a uint
//   - 200 on success (no "data" field)
func FixVuln(c *gin.Context) {
	vulnID, err := strconv.ParseUint(c.Param("id"), 10, 32)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{
			"code": 400,
			"msg":  "漏洞ID格式错误",
		})
		return
	}

	// Use a comma-ok assertion so that a context value of an unexpected type
	// is rejected as unauthenticated instead of panicking the handler.
	rawUserID, exists := c.Get("user_id")
	userID, ok := rawUserID.(uint)
	if !exists || !ok {
		c.JSON(http.StatusUnauthorized, gin.H{
			"code": 401,
			"msg":  "用户未认证",
		})
		return
	}

	if err := vulnService.FixVuln(uint(vulnID), userID); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{
			"code": 400,
			"msg":  err.Error(),
		})
		return
	}

	c.JSON(http.StatusOK, gin.H{
		"code": 200,
		"msg":  "标记修复成功",
	})
}

// RetestVuln records a retest of a vulnerability.
//
// The vulnerability ID comes from the ":id" path parameter and must parse as
// an unsigned 32-bit decimal. The JSON body must contain a non-empty "result"
// field. The caller's "user_id" is read from the gin context, and the ID,
// result and user ID are forwarded to vulnService.RetestVuln.
//
// Responses:
//   - 400 when the ID or body is invalid, or the service returns an error
//   - 401 when "user_id" is absent from the context or is not a uint
//   - 200 on success (no "data" field)
func RetestVuln(c *gin.Context) {
	vulnID, err := strconv.ParseUint(c.Param("id"), 10, 32)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{
			"code": 400,
			"msg":  "漏洞ID格式错误",
		})
		return
	}

	// RetestRequest is the request payload accepted by this handler only.
	type RetestRequest struct {
		Result string `json:"result" binding:"required"`
	}

	var req RetestRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{
			"code": 400,
			"msg":  "参数错误: " + err.Error(),
		})
		return
	}

	// Use a comma-ok assertion so that a context value of an unexpected type
	// is rejected as unauthenticated instead of panicking the handler.
	rawUserID, exists := c.Get("user_id")
	userID, ok := rawUserID.(uint)
	if !exists || !ok {
		c.JSON(http.StatusUnauthorized, gin.H{
			"code": 401,
			"msg":  "用户未认证",
		})
		return
	}

	if err := vulnService.RetestVuln(uint(vulnID), req.Result, userID); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{
			"code": 400,
			"msg":  err.Error(),
		})
		return
	}

	c.JSON(http.StatusOK, gin.H{
		"code": 200,
		"msg":  "复测成功",
	})
}

// AddVulnComment adds a comment to a vulnerability.
//
// The vulnerability ID comes from the ":id" path parameter and must parse as
// an unsigned 32-bit decimal. The JSON body must contain a non-empty "content"
// field. The caller's "user_id" is read from the gin context, and the ID,
// content and user ID are forwarded to vulnService.AddComment.
//
// Responses:
//   - 400 when the ID or body is invalid, or the service returns an error
//   - 401 when "user_id" is absent from the context or is not a uint
//   - 200 with the value returned by the service in "data" on success
func AddVulnComment(c *gin.Context) {
	vulnID, err := strconv.ParseUint(c.Param("id"), 10, 32)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{
			"code": 400,
			"msg":  "漏洞ID格式错误",
		})
		return
	}

	// CommentRequest is the request payload accepted by this handler only.
	type CommentRequest struct {
		Content string `json:"content" binding:"required"`
	}

	var req CommentRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{
			"code": 400,
			"msg":  "参数错误: " + err.Error(),
		})
		return
	}

	// Use a comma-ok assertion so that a context value of an unexpected type
	// is rejected as unauthenticated instead of panicking the handler.
	rawUserID, exists := c.Get("user_id")
	userID, ok := rawUserID.(uint)
	if !exists || !ok {
		c.JSON(http.StatusUnauthorized, gin.H{
			"code": 401,
			"msg":  "用户未认证",
		})
		return
	}

	comment, err := vulnService.AddComment(uint(vulnID), req.Content, userID)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{
			"code": 400,
			"msg":  err.Error(),
		})
		return
	}

	c.JSON(http.StatusOK, gin.H{
		"code": 200,
		"msg":  "评论添加成功",
		"data": comment,
	})
}

// GetVulnStats returns vulnerability statistics.
//
// It calls vulnService.GetVulnStats and writes the result in the "data" field
// of a 200 response. If the service returns an error, it responds with 500
// and the error text in "msg".
func GetVulnStats(c *gin.Context) {
	result, statsErr := vulnService.GetVulnStats()

	// Success path first; everything below handles the failure case.
	if statsErr == nil {
		okBody := gin.H{
			"code": 200,
			"msg":  "获取成功",
			"data": result,
		}
		c.JSON(http.StatusOK, okBody)
		return
	}

	failBody := gin.H{
		"code": 500,
		"msg":  statsErr.Error(),
	}
	c.JSON(http.StatusInternalServerError, failBody)
}

// GetVulnTimeline returns the timeline of a vulnerability.
//
// The vulnerability ID comes from the ":id" path parameter and must parse as
// an unsigned 32-bit decimal; otherwise the handler responds with 400. The ID
// is passed to vulnService.GetVulnTimeline; a service error produces a 500
// response, and success produces a 200 response with the timeline in "data".
//
// NOTE(review): unlike GetVuln, this handler does not read "user_id" or
// "role_code" from the context — confirm access control is enforced elsewhere.
func GetVulnTimeline(c *gin.Context) {
	parsedID, parseErr := strconv.ParseUint(c.Param("id"), 10, 32)
	if parseErr != nil {
		badID := gin.H{
			"code": 400,
			"msg":  "漏洞ID格式错误",
		}
		c.JSON(http.StatusBadRequest, badID)
		return
	}

	events, svcErr := vulnService.GetVulnTimeline(uint(parsedID))
	if svcErr == nil {
		c.JSON(http.StatusOK, gin.H{
			"code": 200,
			"msg":  "获取成功",
			"data": events,
		})
		return
	}

	c.JSON(http.StatusInternalServerError, gin.H{
		"code": 500,
		"msg":  svcErr.Error(),
	})
}
